iT邦幫忙

2023 iThome 鐵人賽

DAY 30
0
AI & Data

紮實的ML機器學習原理~打造你對資料使用sklearn的靈敏度系列 第 30

DAY 30 「Kaggle Fashion MNIST數據集」來做CNN模型融合10種不同的服裝分類啦~

  • 分享至 

  • xImage
  •  

Kaggle Fashion MNIST數據集分類

使用卷積神經網絡 (CNN) 作為基模型應用模型融合方法來提升性能~

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Conv2D, MaxPooling2D, Flatten, Dropout

# 加載數據
train_data = pd.read_csv('fashion-mnist_train.csv')
test_data = pd.read_csv('fashion-mnist_test.csv')

# 提取特征和目標
X = train_data.iloc[:, 1:].values.reshape(-1, 28, 28, 1) / 255.0  # 將像素值歸一化到 [0, 1]
y = train_data.iloc[:, 0].values

# 劃分訓練集和驗證集
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

# 定義 CNN 模型
model = Sequential([
    Conv2D(32, (3,3), activation='relu', input_shape=(28, 28, 1)),
    MaxPooling2D((2, 2)),
    Flatten(),
    Dense(128, activation='relu'),
    Dropout(0.5),
    Dense(10, activation='softmax')
])

# 編譯模型
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# 訓練模型
model.fit(X_train, y_train, epochs=5, validation_data=(X_val, y_val))

# 使用簡單平均法進行模型融合
def ensemble_predict(models, X):
    preds = [model.predict(X) for model in models]
    return np.mean(preds, axis=0)

# 訓練多個 CNN 模型
num_models = 3
models = [Sequential([model.layers[i] for i in range(len(model.layers)-1)]) for _ in range(num_models)]

for model in models:
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    model.fit(X_train, y_train, epochs=5, validation_data=(X_val, y_val))

# 獲取基模型的預測結果
ensemble_preds = ensemble_predict(models, X_val)
ensemble_preds = np.argmax(ensemble_preds, axis=1)

# 計算準確度
ensemble_accuracy = accuracy_score(y_val, ensemble_preds)
print(f'Ensemble Accuracy on Validation Set: {ensemble_accuracy}')

加載了Fashion MNIST數據集,然後構建了一個簡單的卷積神經網絡(CNN)作為基模型。接著,我們訓練了該基模型,並將其覆制多份以構建多個基模型。最後,我們使用簡單平均法進行模型融合,並計算了模型在驗證集上的準確度。


上一篇
DAY 29 「Kaggle 貓狗分類」來做模型融合+超參數調優進行分類啦~
系列文
紮實的ML機器學習原理~打造你對資料使用sklearn的靈敏度30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言